"""
@Project    : cosmo-face
@Module     : rfbnet_face.py
@Author     : HuangJiWen[huangjiwen@haier.com]
@Created    : 2020/8/17 15:42
@Desc       : CenterNet-style face keypoint model built on an RFB backbone
"""

import torch.nn as nn

from models.center_face.head import KeypointHead
from models.center_face.rfbnet import RFB


class CenterRFBNetFace(nn.Module):
    """CenterNet-style face model: an ``RFB`` backbone feeding a ``KeypointHead``.

    Args:
        cfg: mapping that must contain the keys ``"in_channels"``,
            ``"head_conv"`` and ``"num_keypoints"``. They are forwarded
            positionally, in that order, to ``KeypointHead``. A missing key
            raises ``KeyError`` from the lookup.
    """

    def __init__(self, cfg):
        super().__init__()

        # Build the backbone first so submodule registration order (and thus
        # state_dict ordering / any init-time RNG consumption) is preserved.
        self.backbone_model = RFB()

        # NOTE(review): cfg["in_channels"] presumably has to match the number
        # of channels RFB outputs (80 per the example in forward) — confirm.
        in_channels = cfg["in_channels"]
        head_conv = cfg["head_conv"]
        num_keypoints = cfg["num_keypoints"]
        self.head_model = KeypointHead(in_channels, head_conv, num_keypoints)

    def forward(self, x):
        """Extract backbone features from ``x`` and decode them with the head.

        Args:
            x: input image batch. The original author's example shape is
                ``torch.Size([2, 3, 608, 608])`` (presumably NCHW — confirm).

        Returns:
            Whatever ``self.head_model`` returns for the backbone features.
        """
        # Per the original author's note, a 608x608 input yields features of
        # shape torch.Size([2, 80, 152, 152]) (i.e. stride 4) — not checked here.
        features = self.backbone_model(x)
        return self.head_model(features)
